#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author ：hhx
@Date ：2022/5/29 10:38 
@Description ：无监督模型
"""
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import os
from sklearn.cluster import KMeans
import copy

device = 'cpu'  # run everything on CPU; GPU line kept below for reference
# device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
torch.set_default_tensor_type(torch.DoubleTensor)  # new tensors default to float64
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # tolerate duplicate OpenMP runtimes (MKL + matplotlib crash workaround)


class Vertex:
    """A graph vertex identified by a key, holding weighted links to neighbours."""

    def __init__(self, key):
        self.id = key
        # Maps neighbouring Vertex objects to the weight of the connecting edge.
        self.connectedTo = {}

    def addNeighbor(self, nbr, weight=0):
        """Record (or overwrite) an edge from this vertex to *nbr*."""
        self.connectedTo[nbr] = weight

    def __str__(self):
        neighbour_ids = [vertex.id for vertex in self.connectedTo]
        return '{} connectedTo: {}'.format(self.id, neighbour_ids)

    def getConnections(self):
        """Return a view over the neighbouring Vertex objects."""
        return self.connectedTo.keys()

    def getID(self):
        """Return this vertex's key."""
        return self.id

    def getWeight(self, nbr):
        """Return the weight of the edge leading to *nbr*."""
        return self.connectedTo[nbr]


class Graph(object):
    """Undirected weighted network stored as an adjacency matrix."""

    def __init__(self, kind):
        # Graph flavour: Undigraph, Digraph, Undinetwork, Dinetwork.
        # Only 'Undinetwork' actually stores edges (see InsertArc).
        self.kind = kind
        self.vertexs = []   # vertex table (list of Vertex objects)
        self.arcs = []      # adjacency matrix; float('inf') means "no edge"
        self.vexnum = 0     # current number of vertices
        self.arcnum = 0     # current number of edges
    def CreateGraph(self, vertex_list, edge_list):
        """Build the vertex table and adjacency matrix.

        :param vertex_list: vertex keys
        :param edge_list: (u_key, v_key, weight) triples
        """
        self.vexnum = len(vertex_list)
        self.arcnum = len(edge_list)
        for key in vertex_list:
            self.vertexs.append(Vertex(key))
            # Row of the adjacency matrix, initialised to "no edge".
            self.arcs.append([float('inf')] * self.vexnum)
        for edge in edge_list:
            start = self.LocateVertex(edge[0])
            end = self.LocateVertex(edge[1])
            self.InsertArc(start, end, edge[2])

    def LocateVertex(self, vertex):
        """Return the index of the vertex whose key equals *vertex* (None if absent)."""
        for position, candidate in enumerate(self.vertexs):
            if candidate.id == vertex:
                return position

    def InsertArc(self, ivertex, jvertex, weight):
        """Write *weight* symmetrically into the adjacency matrix (undirected net only)."""
        if self.kind != 'Undinetwork':
            return
        self.arcs[ivertex][jvertex] = weight
        self.arcs[jvertex][ivertex] = weight

    def GetMin(self, closedge):
        """Return the index of the cheapest candidate edge in *closedge*.

        Entries whose weight is 0 are already inside the spanning tree and
        are skipped.
        """
        best = 0
        best_weight = float('inf')
        for idx in range(self.vexnum):
            weight = closedge[idx][1]
            if weight != 0 and weight < best_weight:
                best_weight = weight
                best = idx
        return best

    def Prim(self, start_vertex):
        """Prim's minimum spanning tree rooted at *start_vertex*.

        :return: list of [u_key, v_key, weight] tree edges
        """
        k = self.LocateVertex(start_vertex)
        # closedge[v] = [u, w]: cheapest known edge from the tree (via u) to v.
        closedge = [[k, self.arcs[k][v]] for v in range(self.vexnum)]
        closedge[k][1] = 0  # the start vertex joins the tree immediately
        arc = []
        for _ in range(self.vexnum - 1):
            # Cheapest edge leaving the current tree.
            nearest = self.GetMin(closedge)
            arc.append([self.vertexs[closedge[nearest][0]].id,
                        self.vertexs[nearest].id,
                        closedge[nearest][1]])
            # Mark the new vertex as inside the tree.
            closedge[nearest][1] = 0
            # Relax candidate edges through the newly added vertex.
            for other in range(self.vexnum):
                if self.arcs[nearest][other] < closedge[other][1]:
                    closedge[other] = [nearest, self.arcs[nearest][other]]
        return arc

    def AddEdges(self):
        """Collect every finite adjacency-matrix entry as [u_key, v_key, w].

        Each undirected edge appears twice (once per direction); the list is
        returned sorted by weight ascending.
        """
        edges = [
            [self.vertexs[i].id, self.vertexs[j].id, self.arcs[i][j]]
            for i in range(self.vexnum)
            for j in range(self.vexnum)
            if self.arcs[i][j] != float('inf')
        ]
        edges.sort(key=lambda item: item[2])
        return edges

    def Kruskal(self):
        """Kruskal's minimum spanning tree.

        :return: the kept edges; duplicates and cycle-closing edges are
                 removed from the sorted edge list in place.
        """
        edges = self.AddEdges()
        # flags[v] is the label of the connected component containing v.
        flags = list(range(self.vexnum))
        index = 0
        while index < len(edges):
            u = self.LocateVertex(edges[index][0])
            v = self.LocateVertex(edges[index][1])
            if flags[u] == flags[v]:
                # Endpoints already connected: this edge would close a cycle.
                edges.pop(index)
            else:
                # Merge v's component label into u's.
                absorbed, keeper = flags[v], flags[u]
                for w in range(self.vexnum):
                    if flags[w] == absorbed:
                        flags[w] = keeper
                # Advance only after keeping an edge (pop shrinks the list).
                index += 1
        return edges


def formatnum(n):
    """Prefix *n* with a single '0' when it is below ten, else plain str(n)."""
    return str(n) if n >= 10 else '0' + str(n)


def CreateVertex(nums):
    """Create the vertex keys 0 .. nums-1 as a list."""
    return list(range(nums))


def P(outputs1, outputs2):
    """Pearson correlation coefficient between the two 1-D series."""
    # np.corrcoef(x, y) stacks the inputs itself; [0][-1] is the off-diagonal
    # entry of the resulting 2x2 correlation matrix.
    return np.corrcoef(outputs1, outputs2)[0][-1]


def GetFeatures(data):
    """Return the pairwise Pearson-distance matrix for the rows of *data*.

    Entry (i, j) is 1 - corr(data[i], data[j]): 0 for perfectly correlated
    rows, 2 for perfectly anti-correlated rows.

    :param data: 2-D array-like, one series per row
    :return: (n, n) symmetric numpy array of distances
    """
    data = np.asarray(data)
    # np.corrcoef computes the full row-wise correlation matrix in a single
    # vectorised call, replacing the original O(n^2) Python loop over P().
    return 1 - np.corrcoef(data)

def CreateEdges_new(VertexList, Features):
    """Build the complete edge list over all unordered vertex pairs.

    :param VertexList: vertex keys
    :param Features: pairwise distance matrix, indexable as Features[i, j]
    :return: list of (u, v, weight) tuples, one per pair i < j
    """
    EdgeList = []
    for i in range(len(VertexList)):
        for j in range(i + 1, len(VertexList)):
            # (debug print of (i, j) removed — leftover instrumentation)
            EdgeList.append((VertexList[i], VertexList[j], Features[i, j]))
    return EdgeList


def CreateEdges(VertexList, Features, period=12, noise_step=2):
    """Build the sparse edge list used for the initial segmentation graph.

    Three families of edges are created, each weighted by the corresponding
    Features entry:
      * adjacent edges (i, i+1),
      * periodic edges (i, i+period) — 12 samples per period by default,
      * noise edges (i, i+noise_step) that bridge over a single outlier.

    :param VertexList: vertex keys, one per time step
    :param Features: pairwise distance matrix, indexable as Features[i, j]
    :param period: lag of the periodic edges (default 12, as before)
    :param noise_step: lag of the noise-bridging edges (default 2, as before)
    :return: list of (u, v, weight) tuples
    """
    n = len(VertexList)
    EdgeList = []

    # Adjacent edges.
    for i in range(n - 1):
        EdgeList.append((VertexList[i], VertexList[i + 1], Features[i, i + 1]))

    # Periodic edges. range() also fixes the original 'while True' loops,
    # which raised IndexError / never terminated whenever n < lag.
    for i in range(n - period):
        EdgeList.append((VertexList[i], VertexList[i + period], Features[i, i + period]))

    # Noise edges (skip one intermediate sample).
    for i in range(n - noise_step):
        EdgeList.append((VertexList[i], VertexList[i + noise_step], Features[i, i + noise_step]))

    return EdgeList


def is_exist_by_double_list(target, element):
    """Locate *element* inside a list of lists.

    :param target: list of sub-lists to search
    :param element: value to look for
    :return: (True, index of the first sub-list containing it) or (False, -1)
    """
    position = 0
    for row in target:
        if element in row:
            return True, position
        position += 1
    return False, -1


def GetInitSeg(data, th):
    """Produce the initial segmentation by cutting the MST of the frame graph.

    A graph is built over the rows of *data*, weighted by Pearson distance;
    its minimum spanning tree is computed with Kruskal, every edge whose
    weight is not below the th-quantile edge weight is dropped, and the
    surviving connected components are returned as lists of vertex ids.

    :param data: 2-D array, one series per row
    :param th: quantile (0..1) of MST edges used as the cut threshold
    :return: list of components, each a list of row indices
    """
    similarity = GetFeatures(data)  # pairwise Pearson-distance matrix
    vertices = CreateVertex(data.shape[0])
    graph = Graph(kind='Undinetwork')
    graph.CreateGraph(vertex_list=vertices,
                      edge_list=CreateEdges(vertices, similarity))
    mst = graph.Kruskal()
    # Edge weight at the th-quantile position of the (sorted) MST edges.
    threshold = mst[int(len(mst) * th)][-1]
    kept = [edge for edge in mst if edge[-1] < threshold]  # cut heavy edges
    # Union the components: start from singletons, merge along kept edges.
    components = [[v] for v in CreateVertex(data.shape[0])]
    for edge in kept:
        u, v = edge[:2]
        _, target_idx = is_exist_by_double_list(components, u)  # merge target
        _, source_idx = is_exist_by_double_list(components, v)  # absorbed side
        components[target_idx].extend(components[source_idx])
        components.pop(source_idx)
    return components


class MyNet(nn.Module):
    """1-D conv feature extractor: two (conv3-BN-ReLU, conv1-BN) stages,
    with the trailing BatchNorm left un-activated."""

    def __init__(self, inp_dim, mod_dim1, mod_dim2):
        super(MyNet, self).__init__()
        # Channel plan; even stages use kernel 3 (padded), odd stages kernel 1.
        plan = [(inp_dim, mod_dim1), (mod_dim1, mod_dim2),
                (mod_dim2, mod_dim1), (mod_dim1, mod_dim2)]
        layers = []
        for stage, (c_in, c_out) in enumerate(plan):
            kernel = 3 if stage % 2 == 0 else 1
            layers.append(nn.Conv1d(c_in, c_out, kernel_size=kernel,
                                    stride=1, padding=kernel // 2))
            layers.append(nn.BatchNorm1d(c_out))
            if stage < len(plan) - 1:  # no ReLU after the final BatchNorm
                layers.append(nn.ReLU(inplace=True))
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)


def kmeansclu(data, n_clusters):
    """Cluster the rows of *data* with k-means.

    :param data: 2-D array, one sample per row
    :param n_clusters: number of clusters
    :return: clusters as lists of row indices (see trans_id2i)

    Note: the original passed algorithm="auto"; that value was the historical
    default alias and was removed in scikit-learn >= 1.3 (it now raises).
    Omitting the argument keeps the old behaviour on every sklearn version.
    """
    kmeans = KMeans(n_clusters=n_clusters, random_state=0)
    kmeans.fit(data)
    cluList = kmeans.labels_
    return trans_id2i(cluList)

def trans_id2i(data):
    Dict = {}
    for id, i in enumerate(data):
        if i not in Dict:
            Dict[i] = []
        Dict[i].append(id)
    res = []
    for k, v in Dict.items():
        res.append(v)
    return res

def filter_noise(InitSeg, VertexList, windows_size):
    """Smooth the segmentation with a sliding majority-vote window.

    A window of *windows_size* (must be odd) slides over VertexList. When the
    centre element's cluster is not among the majority clusters of the window
    — and the window's edge clusters agree — the centre element is moved into
    the first majority cluster. Clusters emptied by the moves are dropped.

    :param InitSeg: list of clusters, each a list of vertex keys
    :param VertexList: ordered vertex keys to slide over
    :param windows_size: odd window length
    :raises ValueError: if windows_size is even
    :return: smoothed list of non-empty clusters
    """
    if windows_size % 2 != 1:
        raise ValueError('windows_size needs to be set to an odd number')

    half = int(windows_size / 2)
    clusters = copy.deepcopy(InitSeg)

    def cluster_of(date):
        # Index of the cluster currently holding *date* (mutates as we move items).
        for idx, members in enumerate(clusters):
            if date in members:
                return idx

    def counts_of(window):
        # Occurrence count of each cluster index appearing in the window.
        return {value: window.count(value) for value in set(window)}

    def majority_keys(counter):
        # Every key reaching the maximum count, in the counter's insertion order.
        top = max(counter.values())
        return [key for key, cnt in counter.items() if cnt == top]

    for start in range(len(VertexList) - windows_size + 1):
        window = [cluster_of(date) for date in VertexList[start:start + windows_size]]
        counter = counts_of(window)
        center_cluster = window[half]
        center_label = VertexList[start + half]
        edges_agree = (window[-1] == window[0] or window[-2] == window[0]
                       or window[-1] == window[1] or window[-2] == window[1])
        if center_cluster not in majority_keys(counter) and edges_agree:
            # Re-home the centre element into the first majority cluster.
            clusters[majority_keys(counter)[0]].append(center_label)
            clusters[center_cluster].remove(center_label)

    return [members for members in clusters if len(members) != 0]

if __name__ == '__main__':
    data = np.load('../data.npy')  # raw series; presumably (n_samples, n_features) — verify
    # InitSeg = kmeansclu(data, 5)   # alternative: k-means clustering
    InitSeg = GetInitSeg(data, 0.7)  # initial MST-based segmentation (0.7-quantile cut)
    # InitSeg = filter_noise(InitSeg, CreateVertex(data.shape[0]), 3)
    print('初始分割出了{}个簇'.format(len(InitSeg)))  # "initial segmentation produced {} clusters"
    # Visualise the initial clusters: x = sample index, y = cluster number.
    for i in range(len(InitSeg)):
        plt.scatter(InitSeg[i], [i for j in InitSeg[i]])
    plt.show()
    # (n, f) -> add axis -> transpose to (1, f, n): a batch of one sequence for Conv1d.
    X = data[:, :, np.newaxis].T
    X = torch.from_numpy(X).to(device)
    dims = 8  # output channels == upper bound on the number of pseudo-labels
    model = MyNet(inp_dim=64, mod_dim1=16, mod_dim2=dims).to(device)  # inp_dim=64 assumes 64 features — TODO confirm against data
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=5e-2, momentum=0.9)
    model.train()

    # Unsupervised refinement loop: (a) take the per-sample argmax channel as
    # pseudo-labels, (b) force each initial segment to its majority label,
    # (c) train the network towards the refined labels.
    for batch_idx in range(100):
        '''forward'''
        optimizer.zero_grad()
        output = model(X)[0]
        output = output.permute(1, 0).view(-1, dims)  # -> (n_samples, dims)
        target = torch.argmax(output, 1)  # current pseudo-label per sample
        # print(target)
        im_target = target.data.cpu().numpy()
        '''refine'''
        # Majority vote inside every initial segment.
        for inds in InitSeg:
            sort_id = np.argsort(inds)
            # print(sort_id)
            # print(im_target[np.array(inds)[sort_id.astype('int')]])
            u_labels, hist = np.unique(im_target[inds], return_counts=True)
            # print(sort_id, inds, im_target[inds], u_labels, hist)
            im_target[inds] = u_labels[np.argmax(hist)]  # assign the majority label
            # print(im_target[inds])
        '''backward'''
        target = torch.from_numpy(im_target)
        target = target.to(device)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        '''show image'''
        un_label, lab_inverse = np.unique(
            im_target,
            return_inverse=True,
        )
        if len(un_label) < 3:
            # Stop once fewer than 3 distinct labels remain.
            break
        im_target = trans_id2i(im_target)  # labels -> clusters of sample indices
        # print(im_target)
        if batch_idx == 99 or batch_idx == 0:
            # Plot the clustering at the first and the last iteration.
            for i in range(len(im_target)):
                plt.scatter(im_target[i], [i for j in im_target[i]])
            plt.show()

    # for i in range(len(im_target)):
    #     plt.scatter(target[i], [i for j in target[i]])
    # plt.show()
    # print(InitSeg)